/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.serialize;

import com.datastax.oss.driver.api.core.cql.BoundStatement;
import com.datastax.oss.driver.api.core.cql.ColumnDefinitions;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.internal.core.type.DefaultListType;
import com.datastax.oss.driver.internal.core.type.DefaultMapType;
import com.datastax.oss.driver.internal.core.type.DefaultSetType;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.seatunnel.api.table.type.*;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.exception.GaussDBCassandraErrorCode;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.exception.GaussDBConnectorException;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneId;
import java.util.*;

public class TypeConvertUtil {

    /**
     * The main function of this code is to convert the data type in Cassandra into
     * the
     * corresponding SeaTunnel data type for use in SeaTunnel.
     *
     * @param type
     * @return
     */
    public static SeaTunnelDataType<?> convert(DataType type) {
        switch (type.getProtocolCode()) {
            case ProtocolConstants.DataType.VARCHAR:
            case ProtocolConstants.DataType.VARINT:
            case ProtocolConstants.DataType.ASCII:
            case ProtocolConstants.DataType.UUID:
            case ProtocolConstants.DataType.INET:
            case ProtocolConstants.DataType.TIMEUUID:
                return BasicType.STRING_TYPE;
            case ProtocolConstants.DataType.TINYINT:
                return BasicType.BYTE_TYPE;
            case ProtocolConstants.DataType.SMALLINT:
                return BasicType.SHORT_TYPE;
            case ProtocolConstants.DataType.INT:
                return BasicType.INT_TYPE;
            case ProtocolConstants.DataType.BIGINT:
            case ProtocolConstants.DataType.COUNTER:
                return BasicType.LONG_TYPE;
            case ProtocolConstants.DataType.FLOAT:
                return BasicType.FLOAT_TYPE;
            case ProtocolConstants.DataType.DOUBLE:
            case ProtocolConstants.DataType.DECIMAL:
                return BasicType.DOUBLE_TYPE;
            case ProtocolConstants.DataType.BOOLEAN:
                return BasicType.BOOLEAN_TYPE;
            case ProtocolConstants.DataType.TIME:
                return LocalTimeType.LOCAL_TIME_TYPE;
            case ProtocolConstants.DataType.DATE:
                return LocalTimeType.LOCAL_DATE_TYPE;
            case ProtocolConstants.DataType.TIMESTAMP:
                return LocalTimeType.LOCAL_DATE_TIME_TYPE;
            case ProtocolConstants.DataType.BLOB:
                return ArrayType.BYTE_ARRAY_TYPE;
            case ProtocolConstants.DataType.MAP:
                if (type instanceof DefaultMapType)
                    return new MapType<>(convert(((DefaultMapType) type).getKeyType()),
                            convert(((DefaultMapType) type).getValueType()));
            case ProtocolConstants.DataType.LIST:
                if (type instanceof DefaultListType)
                    return convertToArrayType(convert(((DefaultListType) type).getElementType()));
            case ProtocolConstants.DataType.SET:
                if (type instanceof DefaultSetType)
                    return convertToArrayType(convert(((DefaultSetType) type).getElementType()));
            default:
                throw new GaussDBConnectorException(CommonErrorCode.UNSUPPORTED_DATA_TYPE,
                        "Unsupported this data type: " + type);
        }
    }

    private static ArrayType<?, ?> convertToArrayType(SeaTunnelDataType<?> dataType) {
        if (BasicType.STRING_TYPE.equals(dataType)) {
            return ArrayType.STRING_ARRAY_TYPE;
        } else if (BasicType.BYTE_TYPE.equals(dataType)) {
            return ArrayType.BYTE_ARRAY_TYPE;
        } else if (BasicType.SHORT_TYPE.equals(dataType)) {
            return ArrayType.SHORT_ARRAY_TYPE;
        } else if (BasicType.INT_TYPE.equals(dataType)) {
            return ArrayType.INT_ARRAY_TYPE;
        } else if (BasicType.LONG_TYPE.equals(dataType)) {
            return ArrayType.LONG_ARRAY_TYPE;
        } else if (BasicType.FLOAT_TYPE.equals(dataType)) {
            return ArrayType.FLOAT_ARRAY_TYPE;
        } else if (BasicType.DOUBLE_TYPE.equals(dataType)) {
            return ArrayType.DOUBLE_ARRAY_TYPE;
        } else if (BasicType.BOOLEAN_TYPE.equals(dataType)) {
            return ArrayType.BOOLEAN_ARRAY_TYPE;
        } else {
            throw new GaussDBConnectorException(CommonErrorCode.UNSUPPORTED_DATA_TYPE,
                    "Unsupported this data type: " + dataType);
        }
    }

    /**
     * By traversing each column of the Row object, it is converted according to the
     * data type of
     * the column, and the converted value is stored in the fields array. new
     * SeaTunnelRow(fields)
     * returns the SeaTunnelRow object.
     *
     * @param row
     * @return
     */
    public static SeaTunnelRow buildSeaTunnelRow(Row row) {
        DataType subType;
        Class<?> typeClass;
        Object[] fields = new Object[row.size()];
        ColumnDefinitions metaData = row.getColumnDefinitions();
        for (int i = 0; i < row.size(); i++) {
            switch (metaData.get(i).getType().getProtocolCode()) {
                case ProtocolConstants.DataType.ASCII:
                case ProtocolConstants.DataType.VARCHAR:
                    fields[i] = row.getString(i);
                    break;
                case ProtocolConstants.DataType.VARINT:
                    fields[i] = Objects.requireNonNull(row.getBigInteger(i)).toString();
                    break;
                case ProtocolConstants.DataType.TIMEUUID:
                case ProtocolConstants.DataType.UUID:
                    fields[i] = Objects.requireNonNull(row.getUuid(i)).toString();
                    break;
                case ProtocolConstants.DataType.INET:
                    fields[i] = Objects.requireNonNull(row.getInetAddress(i)).getHostAddress();
                    break;
                case ProtocolConstants.DataType.TINYINT:
                    fields[i] = row.getByte(i);
                    break;
                case ProtocolConstants.DataType.SMALLINT:
                    fields[i] = row.getShort(i);
                    break;
                case ProtocolConstants.DataType.INT:
                    fields[i] = row.getInt(i);
                    break;
                case ProtocolConstants.DataType.BIGINT:
                    fields[i] = row.getLong(i);
                    break;
                case ProtocolConstants.DataType.FLOAT:
                    fields[i] = row.getFloat(i);
                    break;
                case ProtocolConstants.DataType.DOUBLE:
                    fields[i] = row.getDouble(i);
                    break;
                case ProtocolConstants.DataType.DECIMAL:
                    fields[i] = Objects.requireNonNull(row.getBigDecimal(i)).doubleValue();
                    break;
                case ProtocolConstants.DataType.BOOLEAN:
                    fields[i] = row.getBoolean(i);
                    break;
                case ProtocolConstants.DataType.TIME:
                    fields[i] = row.getLocalTime(i);
                    break;
                case ProtocolConstants.DataType.DATE:
                    fields[i] = row.getLocalDate(i);
                    break;
                case ProtocolConstants.DataType.TIMESTAMP:
                    fields[i] = Timestamp.from(Objects.requireNonNull(row.getInstant(i))).toLocalDateTime();
                    break;
                case ProtocolConstants.DataType.BLOB:
                    fields[i] = ArrayUtils.toObject(Objects.requireNonNull(row.getByteBuffer(i)).array());
                    break;
                case ProtocolConstants.DataType.MAP:
                    subType = metaData.get(i).getType();
                    if (subType instanceof DefaultMapType)
                        fields[i] = row.getMap(i, convert(((DefaultMapType) subType).getKeyType()).getTypeClass(),
                                convert(((DefaultMapType) subType).getValueType()).getTypeClass());
                    break;
                case ProtocolConstants.DataType.LIST:
                    typeClass = null;
                    if (metaData.get(i).getType() instanceof DefaultListType)
                        typeClass = convert(((DefaultListType) metaData.get(i).getType()).getElementType())
                                .getTypeClass();
                    if (String.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, String.class)).toArray(new String[0]);
                    } else if (Byte.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Byte.class)).toArray(new Byte[0]);
                    } else if (Short.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Short.class)).toArray(new Short[0]);
                    } else if (Integer.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Integer.class)).toArray(new Integer[0]);
                    } else if (Long.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Long.class)).toArray(new Long[0]);
                    } else if (Float.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Float.class)).toArray(new Float[0]);
                    } else if (Double.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Double.class)).toArray(new Double[0]);
                    } else if (Boolean.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getList(i, Boolean.class)).toArray(new Boolean[0]);
                    } else {
                        throw new GaussDBConnectorException(CommonErrorCode.UNSUPPORTED_DATA_TYPE,
                                "List unsupported this data type: " + typeClass.toString());
                    }
                    break;
                case ProtocolConstants.DataType.SET:
                    typeClass = null;
                    if (metaData.get(i).getType() instanceof DefaultSetType)
                        typeClass = convert(((DefaultSetType) metaData.get(i).getType()).getElementType())
                                .getTypeClass();
                    if (String.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, String.class)).toArray(new String[0]);
                    } else if (Byte.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Byte.class)).toArray(new Byte[0]);
                    } else if (Short.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Short.class)).toArray(new Short[0]);
                    } else if (Integer.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Integer.class)).toArray(new Integer[0]);
                    } else if (Long.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Long.class)).toArray(new Long[0]);
                    } else if (Float.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Float.class)).toArray(new Float[0]);
                    } else if (Double.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Double.class)).toArray(new Double[0]);
                    } else if (Boolean.class.equals(typeClass)) {
                        fields[i] = Objects.requireNonNull(row.getSet(i, Boolean.class)).toArray(new Boolean[0]);
                    } else {
                        throw new GaussDBConnectorException(CommonErrorCode.UNSUPPORTED_DATA_TYPE,
                                "List unsupported this data type: " + typeClass.toString());
                    }
                    break;
                default:
                    fields[i] = row.getObject(i);
            }
        }
        return new SeaTunnelRow(fields);
    }

    /**
     * According to the given data type, the file value is converted to the
     * corresponding type and
     * the converted value is injected into the binding statement.
     *
     * @param statement
     * @param index
     * @param type
     * @param fileValue
     * @return
     */
    public static BoundStatement reconvertAndInject(BoundStatement statement, int index, DataType type,
                                                    Object fileValue) {
        switch (type.getProtocolCode()) {
            case ProtocolConstants.DataType.VARCHAR:
            case ProtocolConstants.DataType.ASCII:
                if (fileValue instanceof String)
                    statement = statement.setString(index, (String) fileValue);
                return statement;
            case ProtocolConstants.DataType.VARINT:
                if (fileValue instanceof String)
                    statement = statement.setBigInteger(index, new BigInteger((String) fileValue));
                return statement;
            case ProtocolConstants.DataType.UUID:
            case ProtocolConstants.DataType.TIMEUUID:
                if (fileValue instanceof String)
                    statement = statement.setUuid(index, UUID.fromString((String) fileValue));
                return statement;
            case ProtocolConstants.DataType.INET:
                if (fileValue instanceof String) {
                    try {
                        statement = statement.setInetAddress(index, InetAddress.getByName((String) fileValue));
                    } catch (UnknownHostException e) {
                        throw new GaussDBConnectorException(GaussDBCassandraErrorCode.PARSE_IP_ADDRESS_FAILED, e);
                    }
                }
                return statement;
            case ProtocolConstants.DataType.TINYINT:
                if (fileValue instanceof Byte)
                    statement = statement.setByte(index, (Byte) fileValue);
                return statement;
            case ProtocolConstants.DataType.SMALLINT:
                if (fileValue instanceof Short)
                    statement = statement.setShort(index, (Short) fileValue);
                return statement;
            case ProtocolConstants.DataType.INT:
                if (fileValue instanceof Integer)
                    statement = statement.setInt(index, (Integer) fileValue);
                return statement;
            case ProtocolConstants.DataType.BIGINT:
            case ProtocolConstants.DataType.COUNTER:
                if (fileValue instanceof Long)
                    statement = statement.setLong(index, (Long) fileValue);
                return statement;
            case ProtocolConstants.DataType.FLOAT:
                if (fileValue instanceof Float)
                    statement = statement.setFloat(index, (Float) fileValue);
                return statement;
            case ProtocolConstants.DataType.DOUBLE:
                if (fileValue instanceof Double)
                    statement = statement.setDouble(index, (Double) fileValue);
                return statement;
            case ProtocolConstants.DataType.DECIMAL:
                if (fileValue instanceof Double)
                    statement = statement.setBigDecimal(index, BigDecimal.valueOf((Double) fileValue));
                return statement;
            case ProtocolConstants.DataType.BOOLEAN:
                if (fileValue instanceof Boolean)
                    statement = statement.setBoolean(index, (Boolean) fileValue);
                return statement;
            case ProtocolConstants.DataType.TIME:
                if (fileValue instanceof LocalTime)
                    statement = statement.setLocalTime(index, (LocalTime) fileValue);
                return statement;
            case ProtocolConstants.DataType.DATE:
                if (fileValue instanceof LocalDate)
                    statement = statement.setLocalDate(index, (LocalDate) fileValue);
                return statement;
            case ProtocolConstants.DataType.TIMESTAMP:
                if (fileValue instanceof LocalDateTime)
                    statement = statement.setInstant(index,
                            ((LocalDateTime) fileValue).atZone(ZoneId.systemDefault()).toInstant());
                return statement;
            case ProtocolConstants.DataType.BLOB:
                if (fileValue instanceof Byte[]) {
                    Byte[] fileValueArray = (Byte[]) fileValue;
                    byte[] primitiveArray = new byte[fileValueArray.length];
                    System.arraycopy(fileValueArray, 0, primitiveArray, 0, fileValueArray.length);
                    statement = statement.setByteBuffer(index, ByteBuffer.wrap(primitiveArray));
                }
                return statement;
            case ProtocolConstants.DataType.MAP:
                if (fileValue instanceof Map)
                    statement = statement.set(index, (Map) fileValue, Map.class);
                return statement;
            case ProtocolConstants.DataType.LIST:
                if (fileValue instanceof Object[])
                    statement = statement.set(index, Arrays.asList((Object[]) fileValue), List.class);
                return statement;
            case ProtocolConstants.DataType.SET:
                if (fileValue instanceof Object[])
                    statement = statement.set(index, new HashSet<>(Arrays.asList((Object[]) fileValue)), Set.class);
                return statement;
            default:
                statement = statement.set(index, fileValue, Object.class);
                return statement;
        }
    }
}
